# Attach HonestDiD with library() rather than require(): library() stops
# immediately if the package is missing, whereas require() only returns FALSE
# and lets the script fail later with a less helpful error.
library(HonestDiD)

# Load necessary data (the statements below expect it to provide `epa_only`)

load("00_data/01_data_processed/data_epa_analysis.Rdata")

# Set treatment year for analysis: 0 where the treatment indicator equals 0,
# 1995 otherwise (one variable per treatment definition: region and state)
epa_only$treatment_year_reg <- ifelse(epa_only$treatment_region == 0, 0, 1995)
epa_only$treatment_year_sta <- ifelse(epa_only$treatment_state == 0, 0, 1995)

# Convert PGM_SYS_ID to a numeric identifier (the integer code of each
# distinct ID after conversion to a factor)
epa_only$PGM_SYS_ID_num <- as.numeric(as.factor(epa_only$PGM_SYS_ID))


# S3 generic for the sensitivity analysis; the method is chosen from the class
# of the first argument supplied.
# See here: https://github.com/asheshrambachan/HonestDiD
honest_did <- function(...) {
  UseMethod("honest_did")
}

#' @title honest_did.AGGTEobj
#'
#' @description a function to compute a sensitivity analysis
#'  using the approach of Rambachan and Roth (2021) when
#'  the event study is estimated using the `did` package
#'
#' @param es Result from aggte (object of class AGGTEobj). It must be a
#'  "dynamic" aggregation computed with a "universal" base period; anything
#'  else raises an error.
#' @param e event time to compute the sensitivity analysis for.
#'  The default value is `e=0` corresponding to the "on impact"
#'  effect of participating in the treatment. Must be a single whole
#'  number between 0 and the number of post-periods minus one.
#' @param type Options are "smoothness" (which conducts a
#'  sensitivity analysis allowing for violations of linear trends
#'  in pre-treatment periods) or "relative_magnitude" (which
#'  conducts a sensitivity analysis based on the relative magnitudes
#'  of deviations from parallel trends in pre-treatment periods).
#' @param gridPoints Number of grid points used for the underlying test
#'  inversion. Default equals 100. User may wish to change the number of grid
#'  points for computational reasons. Only forwarded when
#'  `type = "relative_magnitude"`.
#' @param ... Parameters to pass to `createSensitivityResults` or
#'  `createSensitivityResults_relativeMagnitudes`. For
#'  `type = "relative_magnitude"` the `method` argument is already fixed to
#'  "Conditional" here, so it must not be supplied again through `...`.
#' @return A list with elements `robust_ci` (the sensitivity results),
#'  `orig_ci` (the result of `constructOriginalCS`) and `type`.
honest_did.AGGTEobj <- function(es,
                                e          = 0,
                                type       = c("smoothness", "relative_magnitude"),
                                gridPoints = 100,
                                ...) {

  type <- match.arg(type)

  # Make sure that user is passing in an event study
  if (es$type != "dynamic") {
    stop("need to pass in an event study")
  }

  # A universal base period is required; anything else is an error
  if (es$DIDparams$base_period != "universal") {
    stop("Use a universal base period for honest_did")
  }

  # Recover influence function for event study estimates
  es_inf_func <- es$inf.function$dynamic.inf.func.e

  # Recover variance-covariance matrix
  n <- nrow(es_inf_func)
  V <- t(es_inf_func) %*% es_inf_func / n / n

  # Check time vector is consecutive with referencePeriod = -1
  referencePeriod <- -1
  gapsInPre  <- !all(diff(es$egt[es$egt <= referencePeriod]) == 1)
  gapsInPost <- !all(diff(es$egt[es$egt >= referencePeriod]) == 1)
  if (gapsInPre || gapsInPost) {
    msg <- "honest_did expects a time vector with consecutive time periods;"
    msg <- paste(msg, "please re-code your event study and interpret the results accordingly.", sep="\n")
    stop(msg)
  }

  # Remove the coefficient normalized to zero
  hasReference <- any(es$egt == referencePeriod)
  if (hasReference) {
    referencePeriodIndex <- which(es$egt == referencePeriod)
    # drop = FALSE keeps V a matrix even when a single coefficient remains,
    # so nrow(V) below is always defined
    V    <- V[-referencePeriodIndex, -referencePeriodIndex, drop = FALSE]
    beta <- es$att.egt[-referencePeriodIndex]
  } else {
    beta <- es$att.egt
  }

  nperiods <- nrow(V)
  npre     <- sum(1*(es$egt < referencePeriod))
  npost    <- nperiods - npre
  if (!hasReference && (min(c(npost, npre)) <= 0)) {
    if (npost <= 0) {
      msg <- "not enough post-periods"
    } else {
      msg <- "not enough pre-periods"
    }
    msg <- paste0(msg, " (check your time vector; note honest_did takes -1 as the reference period)")
    stop(msg)
  }

  # Validate e before it is turned into a basis-vector index: a fractional,
  # negative or too-large value would otherwise yield a malformed l_vec
  # (all zeros, truncated index, or wrong length) instead of a clear error.
  if (!is.numeric(e) || length(e) != 1L || is.na(e) || e != round(e)) {
    stop("e must be a single whole-number event time")
  }
  if (e < 0 || e > npost - 1) {
    stop(paste0("e must be between 0 and ", npost - 1,
                " (the number of post-periods minus one); got e = ", e))
  }

  # l_vec selects the (e+1)-th post-treatment coefficient
  baseVec1 <- basisVector(index=(e+1),size=npost)
  orig_ci  <- constructOriginalCS(betahat        = beta,
                                  sigma          = V,
                                  numPrePeriods  = npre,
                                  numPostPeriods = npost,
                                  l_vec          = baseVec1)

  if (type=="relative_magnitude") {
    robust_ci <- createSensitivityResults_relativeMagnitudes(betahat        = beta,
                                                             sigma          = V,
                                                             numPrePeriods  = npre,
                                                             numPostPeriods = npost,
                                                             l_vec          = baseVec1,
                                                             gridPoints     = gridPoints,
                                                             method ="Conditional",
                                                             ...)

  } else if (type == "smoothness") {
    robust_ci <- createSensitivityResults(betahat        = beta,
                                          sigma          = V,
                                          numPrePeriods  = npre,
                                          numPostPeriods = npost,
                                          l_vec          = baseVec1,
                                          ...)
  }

  return(list(robust_ci=robust_ci, orig_ci=orig_ci, type=type))
}

######
# Main analysis for EPA inspections using regional treatment year.
# The seed is fixed once (it was previously set twice in a row) so that any
# random resampling inside the estimation is reproducible.
set.seed(123)

# Group-time effects on the sample with year strictly between 1989 and 2002.
# att_gt is called with an explicit `did::` prefix, matching the aggte call
# below, because this file never attaches the `did` package itself.
# base_period = "universal" is required by honest_did.AGGTEobj.
model_epa_region <- did::att_gt(
  yname = "outcome_epa",
  tname = "year",
  idname = "PGM_SYS_ID_num",
  gname = "treatment_year_reg",
  data = epa_only[epa_only$year < 2002 & epa_only$year > 1989, ],
  alp = .01,
  base_period = "universal"
)

# Aggregate into an event study over event times -5 to 6
es <- did::aggte(model_epa_region, type = "dynamic",
                 min_e = -5, max_e = 6)

es

# Relative-magnitude sensitivity analysis of the event study `es` for event
# times 1, 2 and 3. Every run uses the same Mbar grid and alpha, and the seed
# is reset to 123 immediately before each call.
mbar_grid <- seq(from = 0.5, to = 2, by = 0.1)

run_relative_magnitude <- function(event_time) {
  set.seed(123)
  honest_did(es,
             type    = "relative_magnitude",
             e       = event_time,
             Mbarvec = mbar_grid,
             alpha   = .05)
}

sensitivity_results_e1 <- run_relative_magnitude(1)
sensitivity_results_e2 <- run_relative_magnitude(2)
sensitivity_results_e3 <- run_relative_magnitude(3)




# Build one sensitivity panel per event time. Each panel is drawn from the
# robust and original confidence sets of one honest_did result, then given a
# black-and-white theme, a bottom legend without title, two manual colours
# and its own title.
build_sensitivity_panel <- function(result, panel_title) {
  panel <- createSensitivityPlot_relativeMagnitudes(result$robust_ci,
                                                    result$orig_ci)
  panel <- panel + theme_bw()
  panel <- panel + theme(legend.position = "bottom",
                         legend.title = element_blank())
  panel <- panel + scale_color_manual(values = c("blue", "black"))
  panel + ggtitle(panel_title)
}

p1 <- build_sensitivity_panel(sensitivity_results_e1, "CI for t = 1")
p2 <- build_sensitivity_panel(sensitivity_results_e2, "CI for t = 2")
p3 <- build_sensitivity_panel(sensitivity_results_e3, "CI for t = 3")

# Write the three panels side by side, with one shared legend at the bottom,
# to the appendix figure file.
pdf("03_figures/appendix/figure_a8.pdf", width = 12, height = 3)
invisible(tryCatch(
  print(ggpubr::ggarrange(p1, p2, p3, ncol = 3,
                          common.legend = TRUE,  # TRUE, not T: T can be reassigned
                          legend = "bottom") +
          ggtitle("Effect T=3")),
  # Always close the PDF device, even if drawing fails, so no graphics
  # device is left open; the error itself is still raised afterwards.
  finally = dev.off()
))

# Remove every object this script created from the workspace
rm(list = ls())
